{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "6adfa9fa-b760-49d8-be45-65fb67c5ab48",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "# Mistral 7B deployment guide\n",
    "In this tutorial, you will deploy the Large Model Inference (LMI) container from the AWS Deep Learning Containers (DLC) to SageMaker and run inference with it.\n",
    "\n",
    "Make sure the following permission is granted before running the notebook:\n",
    "\n",
    "* SageMaker access\n",
    "\n",
    "## Step 1: Upgrade the SageMaker SDK and import dependencies"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "59dcc3aa-2cf6-44fc-95d8-d3fc819b5593",
   "metadata": {
    "pycharm": {
     "is_executing": true,
     "name": "#%%\n"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "%pip install sagemaker --upgrade --quiet"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e45b3aab-7136-4c94-8c60-20a40191d08f",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import sagemaker\n",
    "from sagemaker.djl_inference.model import DJLModel\n",
    "\n",
    "role = sagemaker.get_execution_role()  # execution role for the endpoint\n",
    "session = sagemaker.session.Session()  # sagemaker session for interacting with different AWS APIs"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4b0f3f23-33ef-4a39-98fc-cbe03e30fcd6",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "## Step 2: Build the SageMaker endpoint\n",
    "In this step, we will build a SageMaker endpoint from scratch.\n",
    "\n",
    "### Getting the container image URI (optional)\n",
    "\n",
    "Check out available images: [Large Model Inference available DLC](https://github.com/aws/deep-learning-containers/blob/master/available_images.md#large-model-inference-containers)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "765de906-9747-4d69-aa78-0b0a359bd57f",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Choose a specific version of LMI image directly:\n",
    "# image_uri = \"763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.28.0-lmi10.0.0-cu124\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cbd9f984-fc41-47e9-b17e-3b96fb75f49e",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "### Create SageMaker model\n",
    "\n",
    "Here we use the [LMI PySDK](https://sagemaker.readthedocs.io/en/stable/frameworks/djl/using_djl.html) to create the model.\n",
    "\n",
    "Check out more [configuration options](https://docs.djl.ai/docs/serving/serving/docs/lmi/deployment_guide/configurations.html#environment-variable-configurations)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "53bea8c5-8fbf-463b-b430-87d6c32b8806",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "model_id = \"mistralai/Mistral-7B-v0.1\"  # model will be downloaded from the Hugging Face Hub\n",
    "hf_token = os.getenv(\"HF_TOKEN\", \"hf_XXXXXXXXXXX\")    # use your HF_TOKEN to access this model\n",
    "\n",
    "env = {\n",
    "    \"TENSOR_PARALLEL_DEGREE\": \"1\",          # use 1 GPU, set to \"max\" to use all GPUs on the instance\n",
    "    \"HF_TOKEN\": hf_token,\n",
    "    \"OPTION_ROLLING_BATCH\": \"auto\",         # optional, enabled by default\n",
    "    \"OPTION_TRUST_REMOTE_CODE\": \"true\",\n",
    "}\n",
    "\n",
    "model = DJLModel(\n",
    "    model_id=model_id,\n",
    "    env=env,\n",
    "    role=role,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "787cf1c3-ea9f-4a83-84d6-480bbb07f57b",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "### Create SageMaker endpoint\n",
    "\n",
    "Specify the instance type to deploy on and a name for the endpoint."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "08b153e2-3eb6-4882-aeb5-81d7efbb7a9a",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "instance_type = \"ml.g5.2xlarge\"\n",
    "endpoint_name = sagemaker.utils.name_from_base(\"lmi-model\")\n",
    "\n",
    "predictor = model.deploy(\n",
    "    initial_instance_count=1,\n",
    "    instance_type=instance_type,\n",
    "    endpoint_name=endpoint_name,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c5e60c73-f3d8-43ef-9261-23677f03d5cb",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "## Step 3: Run inference"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e21c0102-132c-4c18-8627-223592578c86",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "predictor.predict(\n",
    "    {\"inputs\": \"tell me a story of the little red riding hood\", \"parameters\": {\"max_new_tokens\":128, \"do_sample\":True}}\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "aa8fba6e-aa0b-44b1-82d5-9d369e38e8bb",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "#### Benchmark"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "790bf3f4-9398-41dd-8648-b6e43053b3e7",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "Benchmarking can also be done outside this notebook, in a bash terminal. [awscurl](https://github.com/deepjavalibrary/djl-serving/blob/master/awscurl/README.md) is the benchmark tool used here. You can download it with:\n",
    "\n",
    "```\n",
    "curl -O https://publish.djl.ai/awscurl/awscurl && chmod +x awscurl\n",
    "```\n",
    "\n",
    "See [Benchmarking your Endpoint](https://docs.djl.ai/docs/serving/serving/docs/lmi/deployment_guide/benchmarking-your-endpoint.html) for more details."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ae51100a-1f65-4da9-b2ea-6a5565374b64",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%sh\n",
    "curl -O https://publish.djl.ai/awscurl/awscurl && chmod +x awscurl"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "497898f1-266d-44e2-ae9e-9b9dfb0cee74",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the invocation URL from the session's public region attribute\n",
    "endpoint_url = f\"https://runtime.sagemaker.{session.boto_region_name}.amazonaws.com/endpoints/{endpoint_name}/invocations\"\n",
    "endpoint_url"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ced57091-9962-4f6b-9fc5-6f07207d4867",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "!TOKENIZER=codellama/CodeLlama-34b-hf ./awscurl -c 4 -N 10 -n sagemaker {endpoint_url} \\\n",
    "  -H \"Content-type: application/json\" \\\n",
    "  -d '{{\"inputs\":\"The new movie that got Oscar this year\",\"parameters\":{{\"max_new_tokens\":256, \"do_sample\":true, \"temperature\":0.8, \"top_k\":5}}}}' \\\n",
    "  -t"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5f69dbef-d858-46ed-a163-0a15c65056ca",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "## Clean up"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d56c6f57-4554-4ea1-ad58-f0a12e294d44",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "session.delete_endpoint(endpoint_name)\n",
    "session.delete_endpoint_config(endpoint_name)\n",
    "model.delete_model()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
